{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Use Double Q Learning to Play Taxi-v3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "import sys\n",
    "import logging\n",
    "import itertools\n",
    "\n",
    "import numpy as np\n",
    "np.random.seed(0)\n",
    "import gym\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Send log records at INFO level and above to stdout, each prefixed with an\n",
    "# HH:MM:SS timestamp and the level name.\n",
    "logging.basicConfig(level=logging.INFO,\n",
    "        format='%(asctime)s [%(levelname)s] %(message)s',\n",
    "        stream=sys.stdout, datefmt='%H:%M:%S')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "00:00:00 [INFO] env: <TaxiEnv<Taxi-v3>>\n",
      "00:00:00 [INFO] action_space: Discrete(6)\n",
      "00:00:00 [INFO] observation_space: Discrete(500)\n",
      "00:00:00 [INFO] reward_range: (-inf, inf)\n",
      "00:00:00 [INFO] metadata: {'render.modes': ['human', 'ansi']}\n",
      "00:00:00 [INFO] _max_episode_steps: 200\n",
      "00:00:00 [INFO] _elapsed_steps: None\n",
      "00:00:00 [INFO] id: Taxi-v3\n",
      "00:00:00 [INFO] entry_point: gym.envs.toy_text:TaxiEnv\n",
      "00:00:00 [INFO] reward_threshold: 8\n",
      "00:00:00 [INFO] nondeterministic: False\n",
      "00:00:00 [INFO] max_episode_steps: 200\n",
      "00:00:00 [INFO] _kwargs: {}\n",
      "00:00:00 [INFO] _env_name: Taxi\n"
     ]
    }
   ],
   "source": [
    "# Create the Taxi-v3 environment and seed it with 0.\n",
    "env = gym.make('Taxi-v3')\n",
    "env.seed(0)\n",
    "# Log every instance attribute of the environment object ...\n",
    "for key in vars(env):\n",
    "    logging.info('%s: %s', key, vars(env)[key])\n",
    "# ... and every attribute of its spec (id, reward_threshold, max_episode_steps, ...).\n",
    "for key in vars(env.spec):\n",
    "    logging.info('%s: %s', key, vars(env.spec)[key])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DoubleQLearningAgent:\n",
    "    \"\"\"Tabular Double Q-Learning agent for discrete observation/action spaces.\n",
    "\n",
    "    Two action-value tables are kept in ``self.qs``. On every update the two\n",
    "    tables are swapped with probability 1/2; ``qs[0]`` then selects the greedy\n",
    "    next action and is the table that gets updated, while ``qs[1]`` supplies\n",
    "    the value of that action. Actions are chosen greedily with respect to the\n",
    "    sum of both tables, with epsilon-random exploration in 'train' mode.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, env, gamma=0.99, learning_rate=0.1, epsilon=0.01):\n",
    "        \"\"\"Allocate two zero-initialised Q-tables sized from ``env``.\n",
    "\n",
    "        Args:\n",
    "            env: object exposing ``observation_space.n`` and ``action_space.n``.\n",
    "            gamma: discount factor applied to the bootstrapped value.\n",
    "            learning_rate: step size of each TD update.\n",
    "            epsilon: probability of a uniformly random action in 'train' mode.\n",
    "        \"\"\"\n",
    "        self.gamma = gamma\n",
    "        self.learning_rate = learning_rate\n",
    "        self.epsilon = epsilon\n",
    "        self.action_n = env.action_space.n\n",
    "        self.qs = [np.zeros((env.observation_space.n, env.action_space.n))\n",
    "                for _ in range(2)]\n",
    "\n",
    "    def reset(self, mode=None):\n",
    "        \"\"\"Begin a new episode; ``mode='train'`` enables exploration and learning.\"\"\"\n",
    "        self.mode = mode\n",
    "        if self.mode == 'train':\n",
    "            # Flat record: observation, reward, done, action appended per step.\n",
    "            self.trajectory = []\n",
    "\n",
    "    def step(self, observation, reward, done):\n",
    "        \"\"\"Choose an action for ``observation`` and, in 'train' mode, learn.\n",
    "\n",
    "        Args:\n",
    "            observation: index of the current state.\n",
    "            reward: reward received on the transition into ``observation``.\n",
    "            done: whether ``observation`` is terminal.\n",
    "\n",
    "        Returns:\n",
    "            The index of the chosen action.\n",
    "        \"\"\"\n",
    "        if self.mode == 'train' and np.random.uniform() < self.epsilon:\n",
    "            action = np.random.randint(self.action_n)\n",
    "        else:\n",
    "            # Greedy w.r.t. the summed estimate; only the current row is added\n",
    "            # instead of summing both full tables on every step.\n",
    "            action = (self.qs[0][observation] + self.qs[1][observation]).argmax()\n",
    "        if self.mode == 'train':\n",
    "            self.trajectory += [observation, reward, done, action]\n",
    "            # Eight entries = two recorded steps = one complete transition.\n",
    "            if len(self.trajectory) >= 8:\n",
    "                self.learn()\n",
    "        return action\n",
    "\n",
    "    def close(self):\n",
    "        \"\"\"No resources to release.\"\"\"\n",
    "        pass\n",
    "\n",
    "    def learn(self):\n",
    "        \"\"\"Apply one Double Q-Learning update for the most recent transition.\"\"\"\n",
    "        state, _, _, action, next_state, reward, done, _ = self.trajectory[-8:]\n",
    "\n",
    "        if np.random.randint(2):\n",
    "            self.qs = self.qs[::-1] # swap two elements\n",
    "        # qs[0] selects the greedy next action, qs[1] evaluates it.\n",
    "        a = self.qs[0][next_state].argmax()\n",
    "        v = self.qs[1][next_state, a]\n",
    "        # The reward and the discount enter the TD target exactly once; the\n",
    "        # bootstrapped value is dropped when next_state is terminal.\n",
    "        target = reward + self.gamma * v * (1. - done)\n",
    "        td_error = target - self.qs[0][state, action]\n",
    "        self.qs[0][state, action] += self.learning_rate * td_error\n",
    "\n",
    "\n",
    "# Build the agent; its Q-tables are sized from env.observation_space.n and env.action_space.n.\n",
    "agent = DoubleQLearningAgent(env)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "00:00:01 [INFO] ==== train ====\n",
      "00:00:14 [INFO] ==== test ====\n",
      "00:00:14 [INFO] average episode reward = 7.84 ± 2.32\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX8AAAD4CAYAAAAEhuazAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAggklEQVR4nO3deXxV9Z3/8dcn+76RhawESADZl4iouLCDWrB2EZ1Wa3WcWq1LZ7QydG9pmU5/dXS0i2OrtdMZSzsdtVZrte1oFwWhVUQUjQRlk1UEgoEs398f9yS5CTcJcG9yc+95Px+PPDj3e8499/u9Ce/7vd/zPeeYcw4REfGXhGhXQEREBp7CX0TEhxT+IiI+pPAXEfEhhb+IiA8lRbsCJ6qwsNBVV1dHuxoiIjFl3bp1e51zRd3LYyb8q6urWbt2bbSrISISU8zsrVDlGvYREfEhhb+IiA8p/EVEfEjhLyLiQwp/EREfUviLiPiQwl9ExIdiZp5/NOw7fJRDTS0sf/hlFo4bysdmDMPMOta3tTl+tnYrackJzBpdzFcf20hFXjqLJ5fTsLeReWNL2LznMH+q38uZI4bwwpZ3uXBiKU9ueIfcjGTGDM0mJSmBne81UZqbxhu7DjMkK4WxpTl8/5nNlOWlUVWQQUpSAuPKcnnwuS2829jMh6aV8/zm/ZxTW8i6t95l446DTB2Wx+wxJV3q39bmeOSl7Wx/931WN+ynblgBT2zYyX9cUUdTcyurG/azaPxQ1jTsZ+7YEt7ef4SjzW00Hmvh6Vd38enza1j1wlYq8tOp332YeeNKGF2Szdv7j7CmYT9DslL43au7uWRqOaW56ZTlpVO/+xAjCrNISOh8n5xzfPf/3uTD0yrIz0jhVy/t4IwRBSQlJLD9wPuMLMqk8Vgr5XnpXZ7zp/q9TKzIIzc9mbY2x+u7D5GUkEBOehLJCQnkZ6aw/cD77D7YxJSqfNa9tZ+/1O+jJCeNxAQjMcHITktizmmd78vBpmYaj7ZQmtv5Wvsbj9Ha5ijMSuEHz27mtNIczhtVxF/e3EuiGWeMGMKWvY3sazzKkMxUNu48SIIZr71zkPNHF3PgyDGmVObTsK+R4uxU1jTs5+Ip5V1+F6/vOkRNUef7cuRYC++810Sbc9TvbmREUSav7jzIyKIsWtsckyrzeP9YK7f+4iWWTC5ncmUeh5qaaWlz1BRlsbphPwebmhlZlMmIwizWb3+P1rY2akuyyUlLBuDFrQdISjBSkxKoyM9g18Em3nu/mey0JH72wlY+t3AMZvDmnsPUFGd31NOAkUVZ/P613fzozw1cd/5IhhVkUpyTSnJiAg17G6kpzuIbj7/K0Jw0SnLSKM5JZWpVPrsPNfHs63tY99a7fGZ2LZUFGQDU7z5MfkYyrc6Bg88/vIE7l04hPSWR+t2HOHy0lT/X72XBuKEcOHKMqVX57G08SnJCAuveepeZtYVs3X+EmuLA+/Pj595iUkUuR1va2LjjIB+cWs6QzBRu+dmLTK7M4+u/fpXLz6jitNIc6oblU1Oc1fF/t6m5lUde3M7kynxGlWSx870mmppbyc9IIT8zhdY2R8PeRpISjCPHWhkzNLvj99bU3Mpj63fyoanlXbIA4LH1OxgzNIcnX3mH684bSUKCceDIMbYfeB/noLowkzbnONTUQuPRFqqHZLLwzmf5+sXjOb26gM17GklKNJ7fvI+P1lXSsLeRUSXZJ5FYJ8eidT1/M1sI3AkkAvc551b2tn1dXZ0b6JO8qm//dY/rplcXsGbL/gGsjYj41WtfW0hacuIpPdfM1jnn6rqXR2XYx8wSgXuARcBY4DIzGxuNuoTinKOtrfcPRQW/iAyUpATre6OT3WfE93hipgP1zrnNAGb2ELAE2DhQFTjW0sa3f7uJ62fVkJ6cyPw7nuFLHxhHm3Nc/WNdRkJEBo+kxMj306MV/uXA1qDH24AzBrICj760g3uf3czvX9tN
VmoSW/Yd4fMPb2D7gfcHshoiIiFVFqQzsiiLy6dX9cv+ozXbJ9R3mOPGWczsWjNba2Zr9+zZE9EKtLa1AYEDUS9uPQAEDsLJwPjaxeP73OaskUO4c+nk/q8M8O2PTAIgsR++XvfmXz888YS2e/MbF7Bl5YUdP4/ecDaP33jOCQ8HZKclcd8VdWxZeWGX8pWXTCA3PblL2QUThnZ5HPycN1Ys4oozh5GSFIiOP31uFltWXsib37gAgJy0JJ659XzeWLGIH3x8Gr+5+Zzj6lKQmXJCdQaYWVPI+i/Ppyw3jU+cVc2MEQW9br/ykgkkGNz78WkA/PSaMzh3VOcFLcty0zqWs9NC932D3+dQ7r58CgAfratg9pjiLuumVuWFfM66z8/lpjm1AORnJHfs/9YFozv21W7J5DK2rLyQP942mweums78cUND7jNcUTnga2ZnAl92zi3wHi8DcM59s6fnRPqA70Nr3ub2X77cpSw/I5l3jzRH7DWkZw9cdTqfuP+FHtcvnlTGXZcF/pN1P/D+qxtm8oG7/wTApXWVpCUn8OPnul64sDQ3jW9eMqHjNb70gbHc/ft6Go+1cM3MEXxgUhlPvvIO33nqdW6YVUNtSRY3PfQiF00s5atLxnPgyDFm/79nAPjL7bM5a+XvQ9Zz41cXcNsv1vPY+p0AJCcaL395ARt3HuRwUwub9xxm3rihXHjXH7nn8qmU5aWTk5bEzveaqCnOIi05kaMtrRhGcqLx1MZdDBuSyYJ/e7bjNf7lQxO49PTQvb+2NkdzWxv/+fzb/N0ZVSSYsbphH7/Z8A5fWzKeQ0db2H2widqgWSOPvrSDv771LlOq8lgyOTAryTlHc6vrCPWt+4/w2PqdNDW3csu8UWzZ20hWWhKFWakA7D18lKc27uKyoF5pS2ugQxVqiOJYSxub3jnE/X9p4OY5oyjLS6OlzZGWnEhLaxstbY4259iy9wiNx1r4x1Uvcc/lU5lQkXvcvl7fdYiGvY3MHlPMS1sPMKUqnzf3HKY4O5W8jM4PlqMtraQmJeKcY/ehozQebWFEURYtrW2YGY3ejKv5dwTe6/uuqOOFt/azbNFpHfs4fcXTFGen8vD1Z1O7/Akg8OHQ1NzacQD252u3clZNIcXZqSSYsX7bAQAmVuRxOOj9d86xpmE/04cXdJkptH7bAcYMzen4++k+iyhcPR3wjdawzwtArZkNB7YDS4HLB+rFdx1s4lfrdxxX/n5z60BVwff66nJ8ZfG4juXrZ43knj+82fE4OBDGluXQGOIb20PXzmDYkMyOx1edPZyrzh7eZZvfbHgHALPAf1SAiyaWUZCZ0qV3WhY0BTU3PZn33u/sIGSkJHH35VN5bH3gA+rXN55DWnIiU6vyATp6nS9+cX6X1x7ihShAalLnLI72Xt706gLWvf1uR4+6JwkJRmpCIlfP7GzbObVFnFNb1FHf7j37xZPKWDyprEuZmZGS1Bk6lQUZXHf+yI7H1YWZXbYvzErtEvzQ+7h0SlICEypy+c5HJwdt3/m89uWxZYEQfPa2WT3ua1RJdscUyLrqgo6y7trfVzOjJKezx99ez5y0ZHLSkplSlcff3j7A3LElzB3bdbr0C8vnhqxD8Mybj9RVdlk3xfvdQ9f337xpw921/+0NtKgM+zjnWoAbgCeBV4FVzrlX+vM1X9x6gJ88H+gd/t19q/lz/b7jtmlqbuvPKvjObQtH97wyRPq3D2H88Mo68oPC9yPTKo/b9vllc7hl7ig+NmNYx/kBRdmdgRoc/D1JT0nw/k1keGEmW1ZeyMLxnV+xg4cIJlXmAT2HQbtIzcte9akz+wx+iYyfXH0Gz9x6fp/bPXTtDJ646fhhrFgVtZO8nHOPA48P1OtdfM+fAfj4jGHU7z48UC/rK1eeOazL8Munz69hYnkev3ttF/f/eQtfWTyOLz3a9TP+jOEFrG4ITJtdPLmsS8+wXXCoL78g8JV8aG4aN80NjKEunhTorc+sKeRY
axttQZ/hf7xtFs2toT/Urzyrmqbmti695mB/uPX8jn09eNV0GvY1kpKUwMtfnk9Tcxu7DjZ1bJudlsShJh0zikVZqUlkpfYdhTNC9Npjmc7wlYgpDvpq3W5mbSETKnJpa3NcenplR/i3B/rUYfmcXl3A37a+yxcuDH2qR2ZqEi9+cR4vbHmXed2+lkPg63T7MEfwEArQcYZpKKlJidzoHYTraX273IxkJmfkAZCdlkx2WtcPpSdvPpfNexp73JfIYKPwl1PSfRweAgcNQ8lNT+YrS7rO7hlfnsv/fvosJpTnntAc5ryMlJDBP1iU5aV3OTYgMtgp/OWU1BaHP7YdfGBMRAaW767qGa1rGcWCm+bUcvHksr43hC4Xbmunt1Ykdvgu/N/Qwd4e3TJvFHdcOrnjcWbK8ReSOmtk4KBXcojwP2900XFl3X1wSjl3XDrp1CspIhHhu/BvU/e0V8EnmHzfO0syJTGBX90wk1sXjO6YFRF8HsqTN5/LH2+bxcSKvB7Pimx3x6WT+eCUil63EZH+57vwV/afuPb3avrwAiZU5HL9rJqgtZ3pP3podq+zakRk8PFd+Etof3/O8XPd071hn+ApjfrsFIkPvpvto55/aMtDzLGvG5bPyksmcFHQpQDa378IX35ERAaY73r+n7h/TbSrEHUPfnJ6r+u/eckEPjS1AjNj6fSqkGc/9pb959QW9njFRBEZHHzxP7Qp6IJtuw8djWJNBoczR/Z+mvpl06uOu2hXp0DXv7crD/7k6gG9NYOInAJf9PwvuOuP0a5CVGV4Y/efmR04YJsYFNznjSqitjjrhPeV7J2N2w83FhKRAeSLnr+uuRLwqfNG8o/zu15p88d9DAF19/WLxzNsSCbnjSrue2MRGbR8Ef5+9V/XnEFOejIf+f5zEdvnkKxUbl80BoAbZ9dw+vDe76wkIoOTwj+OnVVT2OVxpGfofHZ+L9frF5FBTSO3MW5oiMsod3fp6YGboSRroF5EPOr5x7gT6c1/8aKxfG7hGIW/iHRQGvhAQoJ1nK0rIgIKfxERX1L4xzhdrkJEToXCX0TEhxT+ceaf5o+KdhVEJAYo/GOcC7rIckFmSi/X5BER6aTwj3Hdx/yHZKWG3lBEJIjCP8bpeK+InAqFf4xI6eEEraQQN1IXEemLwj/GpSR1/gqTE/VBICInRpd3iBV95HpBZgo/vWYGAFfPHM7O994fgEqJSKxS+MeIvvr0N8+tpca7KcsXLjr+frwiIsE07BMjdMN0EYkkhX+MS08OXLAtNUm/ShE5cRr2iXE3z63l9V2H+dDUimhXRURiiMI/RlgPo/7pKUncOKd2gGsjIrFOYwUxQmP+IhJJCv8YoewXkUgKK/zN7F/N7DUzW29m/2tmeUHrlplZvZltMrMFQeXTzOxlb91dZurTngi9TSISSeH2/J8CxjvnJgKvA8sAzGwssBQYBywEvmtm7fcR/B5wLVDr/SwMsw4iInKSwgp/59xvnXMt3sPngfYpJ0uAh5xzR51zDUA9MN3MSoEc59xzzjkHPAhcHE4d/KK93/+RaRWs+oczo1oXEYl9kRzz/yTwhLdcDmwNWrfNKyv3lruXh2Rm15rZWjNbu2fPnghWNXaZwfThBZxTWxjtqohIDOsz/M3saTPbEOJnSdA2y4EW4KftRSF25XopD8k5d69zrs45V1dUVNRXVUNqa4vPix5PqsgDoDhb1+8XkZPX5zx/59zc3tab2ZXARcAcbygHAj36yqDNKoAdXnlFiPJ+88qOg/25+4HT7WPzlnmjWDRhKKeV5kSnPiIS08Kd7bMQ+Byw2Dl3JGjVo8BSM0s1s+EEDuyucc7tBA6Z2Qxvls8VwCPh1KEvew439efuB0z3r0yJCca4styw9nlpXWXfG4lIXAr3DN+7gVTgKW8q4vPOuU85514xs1XARgLDQdc751q951wHPACkEzhG8MRxe42gu35X35+7HzCRnuq5ZeWFEd2fiMSWsMLfOVfT
y7oVwIoQ5WuB8eG8rh9pmr+IRJLO8I0xPV3jR0TkZMR9+MdLjzlOmiEig0Tch3+8GDM0MKsnXj7MRCS6FP6D3PjyHJ665VwWTy6LdlVEJI7EffjHQ0e5tiSb/IwUAEpy0qJcGxGJB3F/M5d4uRrmgnEl3H35FBaMGxrtqohIHIj78O886Tg2tVffzLhoooZ+RCQy4n7YJ9YUZKZEuwoi4gNxH/6xNOxz4YRSLpxQ2qUsxr+4iMggFf/hH+0KeNYsn9PnNgkJg6W2IhLv4j78B0vHuTg79CydSRWdF2ebP7aE2pKsgaqSiPhY3If/YLdkcue9bC6aWMrHZwyLYm1ExC/iPvxjbSDFzJhUmdfxeLB8cxGR+BL/4T/I01/hLiLREPfhH0s6ZiYFTfGJ9fMURGRwUvhH2SD/YiIicUrhHwWvf31Rx3LIfv1gH6sSkZin8I+ClKSub/s1M4dHqSYi4lcK/yhzzvH5i8Z2vadulzH/KFRKROKewl9ExIcU/oNR0Jj/hKAzgEVEIiXuL+k82G94PrY0p8d1X10yjo/WVQ5gbUTEL+K/5z+4s5+zagqPL/QG+ieU55KWnDjANRIRP4j/8NcBUxGR48R/+Eex5/+LT515ak/UPH8R6WdxH/7RjNGa4sDlmUfpMs0iMsjE/QHfaHntawtJS07k9a8vov0eLZMr80hOVK9eRKJP4d9P2g/UBp/N+/D1Z0erOiIiXcT9sI+IiBwv7sNfx05FRI4X9+EvIiLHU/gPQp88uxqA6iGZ0a2IiMStuD/gO9gv7xDKksnlXW7sLiISaer5i4j4UETC38z+ycycmRUGlS0zs3oz22RmC4LKp5nZy966u8xi75DskMyUaFdBRCQsYYe/mVUC84C3g8rGAkuBccBC4Ltm1n6Fsu8B1wK13s/CcOvQe/0iv8+F44eGLM9Ji/tRNBGJE5Ho+d8B3EbXS6gtAR5yzh11zjUA9cB0MysFcpxzzznnHPAgcHEE6tCj/rgTVnJi6Lft6pkjIv9iIiL9IKzwN7PFwHbn3EvdVpUDW4Meb/PKyr3l7uX9xg3gZT1jbwBLRPyqz3EKM3saCDXOsRz4Z2B+qKeFKHO9lPf02tcSGCKiqqqqr6qG3kcMzvYREelvffb8nXNznXPju/8Am4HhwEtmtgWoAP5qZkMJ9OiDb0FVAezwyitClPf02vc65+qcc3VFRUUn2zagf3rjnzx7eMfyv1825bj15XnpJ7QffVMQkWg55WEf59zLzrli51y1c66aQLBPdc69AzwKLDWzVDMbTuDA7hrn3E7gkJnN8Gb5XAE8En4zBlbVkIyO5Q9MKjtu/SVT+x7JWv3Pc/jr5+dFtF4iIieqX6anOOdeMbNVwEagBbjeOdfqrb4OeABIB57wfvrNX97cF5H9DMlMYV/jsYjsC6AkJy1i+xIROVkRO8nL+wawN+jxCufcSOfcaOfcE0Hla72ho5HOuRu8WT+DxoTy3JDldy4NDO+c6JCOiMhgpjN8uxnI2UEiItGi8BcR8SGdktqDb14yga37j/Dd/3sTgLK8NK6eOZzLpgcmMd1/1els3tPYsW1uejL1uw9Hrb4iIidD4d9NUkLgy9D4slwum17F4y/vZMu+I5gZX7hobMd2s0YXM2t0YPmy6YFzEO75Qz0AMXi5IhHxGYV/N3dfPoWfPPcW48pyTvq5V51dza6DTfzDubrMg4gMbgr/biryM1h2wWkdj9tvxH4iffmMlCS+umR8P9VMRCRyFP4EzrTtacLpfVfW8fDftjMs6MQuEZFYp9k+QFZqz5+BFfkZ3DC7VuP4IhJXfB/+ty0czcpLJka7GiIiA8r34f/p82vIz0yOdjVERAaU78NfRMSPFP7omv8i4j8KfxERH1L4i4j4kK/Df9miMYDuqCUi/uPr8K8uzIx2FUREosLX4T+4biMjIjJwfB3+IiJ+pfAXEfEhhT8ndsVOEZF4ovAPUpKTGu0qiIgMCIU/nXfeqirQZZtF
xB8U/iIiPuTz8A/M9SzODgz3TBtWEM3KiIgMGN3Ji8DJXk9/9jyqdbcuEfEJhb+npjgr2lUQERkwvhr2efGL81j/5fnRroaISNT5quefl5HS5bEu7yAifuWrnr+IiAT4Ovw1zi8ifuXr8K8tyY52FUREosK34X/JlPJoV0FEJGp8dcC33RsrFpGo23eJiI/5MvyTE337hUdEBPDxsI+IiJ+FHf5m9hkz22Rmr5jZt4LKl5lZvbduQVD5NDN72Vt3l5nGX0REBlpYwz5mNgtYAkx0zh01s2KvfCywFBgHlAFPm9ko51wr8D3gWuB54HFgIfBEOPUQEZGTE27P/zpgpXPuKIBzbrdXvgR4yDl31DnXANQD082sFMhxzj3nnHPAg8DFYdZBREROUrjhPwo4x8xWm9kzZna6V14ObA3abptXVu4tdy8PycyuNbO1ZrZ2z549YVZVRETa9TnsY2ZPA0NDrFruPT8fmAGcDqwysxGEvi2u66U8JOfcvcC9AHV1dboSj4hIhPQZ/s65uT2tM7PrgF96QzhrzKwNKCTQo68M2rQC2OGVV4QoFxGRARTusM/DwGwAMxsFpAB7gUeBpWaWambDgVpgjXNuJ3DIzGZ4s3yuAB4Jsw4iInKSwj3J60fAj8xsA3AMuNL7FvCKma0CNgItwPXeTB8IHCR+AEgnMMtHM31ERAZYWOHvnDsGfKyHdSuAFSHK1wLjw3ldEREJj87wFRHxIYW/iIgPKfxFRHxI4S8i4kMKfxERH1L4i4j4kMJfRMSHFP4iIj6k8BcR8SGFv4iID/km/EeVZEW7CiIig4Zvwt9C3kpARMSffBP+IiLSSeEvIuJDCn8RER9S+IuI+JBvwr/V6f7vIiLtfBP+R1ta+95IRMQnfBP+6viLiHTyTfiLiEgnhb+IiA/5Jvw17CMi0sk34S8iIp0U/iIiPqTwFxHxIYW/iIgP+Sb8nY74ioh08E34m+l6/iIi7XwT/ur5i4h08k34i4hIJ4W/iIgP+Sb8NeYvItLJN+EvIiKdFP4iIj6k8BcR8aGwwt/MJpvZ82b2opmtNbPpQeuWmVm9mW0yswVB5dPM7GVv3V02QIPxmuopItIp3J7/t4CvOOcmA1/0HmNmY4GlwDhgIfBdM0v0nvM94Fqg1vtZGGYdRETkJIUb/g7I8ZZzgR3e8hLgIefcUedcA1APTDezUiDHOfecC3TFHwQuDrMOIiJykpLCfP7NwJNm9m0CHyRneeXlwPNB223zypq95e7lIZnZtQS+JVBVVRVWRTXoIyLSqc/wN7OngaEhVi0H5gC3OOf+x8w+CvwQmAuEGsd3vZSH5Jy7F7gXoK6uTvktIhIhfYa/c25uT+vM7EHgJu/hz4H7vOVtQGXQphUEhoS2ecvdy0VEZACFO+a/AzjPW54NvOEtPwosNbNUMxtO4MDuGufcTuCQmc3wZvlcATwSZh1EROQkhTvm//fAnWaWBDThjc87514xs1XARqAFuN451+o95zrgASAdeML76Xea6Ski0ims8HfO/QmY1sO6FcCKEOVrgfHhvK6IiIRHZ/iKiPiQwl9ExId8E/5OM/1FRDr4JvxFRKSTb8Jfs31ERDr5JvxFRKSTwl9ExIcU/iIiPhT34V+Rnw7Aig9OiHJNREQGj7gP/5TEQBNHFGVGuSYiIoNH3Ie/iIgcT+EvIuJDCn8RER+K+/BPSw7cNz7ULcRERPwq3Ov5D3r3XjGNX/51O8MLdcBXRKRd3Id/RX4GN86pjXY1REQGlbgf9hERkeMp/EVEfEjhLyLiQwp/EREfUviLiPiQwl9ExIcU/iIiPqTwFxHxIXMxcnNbM9sDvHWKTy8E9kawOoNFvLYL4rdt8douiN+2xXq7hjnniroXxkz4h8PM1jrn6qJdj0iL13ZB/LYtXtsF8du2eG2Xhn1ERHxI4S8i4kN+Cf97o12BfhKv7YL4bVu8tgvit21x2S5fjPmLiEhXfun5i4hIEIW/iIgPxXX4
m9lCM9tkZvVmdnu069MXM6s0sz+Y2atm9oqZ3eSVF5jZU2b2hvdvftBzlnnt22RmC4LKp5nZy966u8ws6neyNLNEM/ubmT3mPY6XduWZ2S/M7DXvd3dmPLTNzG7x/g43mNl/m1larLbLzH5kZrvNbENQWcTaYmapZvYzr3y1mVUPaANPhXMuLn+AROBNYASQArwEjI12vfqocykw1VvOBl4HxgLfAm73ym8H/sVbHuu1KxUY7rU30Vu3BjiTwO2LnwAWDYL2fRb4L+Ax73G8tOvHwDXecgqQF+ttA8qBBiDde7wK+ESstgs4F5gKbAgqi1hbgE8D3/eWlwI/i/bfZZ/vSbQr0I+/7DOBJ4MeLwOWRbteJ9mGR4B5wCag1CsrBTaFahPwpNfuUuC1oPLLgB9EuS0VwO+A2XSGfzy0K8cLSetWHtNt88J/K1BA4HavjwHzY7ldQHW38I9YW9q38ZaTCJwRbP3Vlkj8xPOwT/sfb7ttXllM8L42TgFWAyXOuZ0A3r/F3mY9tbHcW+5eHk3/BtwGtAWVxUO7RgB7gPu9Ia37zCyTGG+bc2478G3gbWAn8J5z7rfEeLu6iWRbOp7jnGsB3gOG9FvNIyCewz/UuGJMzGs1syzgf4CbnXMHe9s0RJnrpTwqzOwiYLdzbt2JPiVE2aBrlyeJwHDC95xzU4BGAkMIPYmJtnnj30sIDHuUAZlm9rHenhKibNC16wSdSltirp3xHP7bgMqgxxXAjijV5YSZWTKB4P+pc+6XXvEuMyv11pcCu73yntq4zVvuXh4tZwOLzWwL8BAw28z+k9hvFwTqtM05t9p7/AsCHwax3ra5QINzbo9zrhn4JXAWsd+uYJFsS8dzzCwJyAX291vNIyCew/8FoNbMhptZCoGDMI9GuU698mYO/BB41Tn3naBVjwJXestXEjgW0F6+1JtpMByoBdZ4X2EPmdkMb59XBD1nwDnnljnnKpxz1QR+D793zn2MGG8XgHPuHWCrmY32iuYAG4n9tr0NzDCzDK8+c4BXif12BYtkW4L39WECf+ODuucf9YMO/fkDXEBgxsybwPJo1+cE6juTwFfF9cCL3s8FBMYOfwe84f1bEPSc5V77NhE0iwKoAzZ46+5mkBx8As6n84BvXLQLmAys9X5vDwP58dA24CvAa16dfkJg9ktMtgv4bwLHLpoJ9NKvjmRbgDTg50A9gRlBI6L9d9nXjy7vICLiQ/E87CMiIj1Q+IuI+JDCX0TEhxT+IiI+pPAXEfEhhb+IiA8p/EVEfOj/A+VzB/PMM+dpAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "def play_episode(env, agent, max_episode_steps=None, mode=None, render=False):\n",
    "    \"\"\"Run a single episode of ``agent`` interacting with ``env``.\n",
    "\n",
    "    The agent is stepped once more after the environment reports ``done`` so\n",
    "    that it sees the final reward. If ``max_episode_steps`` is truthy, the\n",
    "    episode is cut off after that many environment steps, without that extra\n",
    "    agent step.\n",
    "\n",
    "    Args:\n",
    "        env: environment with ``reset()``, ``step(action)`` and ``render()``.\n",
    "        agent: object with ``reset(mode=...)``, ``step(obs, reward, done)``\n",
    "            and ``close()``.\n",
    "        max_episode_steps: optional cap on the number of environment steps.\n",
    "        mode: forwarded to ``agent.reset``.\n",
    "        render: if true, ``env.render()`` is called after every agent step.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(total reward, number of environment steps)``.\n",
    "    \"\"\"\n",
    "    obs, last_reward, finished = env.reset(), 0., False\n",
    "    agent.reset(mode=mode)\n",
    "    total_reward = 0.\n",
    "    step_count = 0\n",
    "    out_of_steps = False\n",
    "    while not out_of_steps:\n",
    "        chosen_action = agent.step(obs, last_reward, finished)\n",
    "        if render:\n",
    "            env.render()\n",
    "        if finished:\n",
    "            break\n",
    "        obs, last_reward, finished, _info = env.step(chosen_action)\n",
    "        total_reward += last_reward\n",
    "        step_count += 1\n",
    "        # Leave the loop once the optional step budget is used up.\n",
    "        out_of_steps = bool(max_episode_steps) and step_count >= max_episode_steps\n",
    "    agent.close()\n",
    "    return total_reward, step_count\n",
    "\n",
    "\n",
    "# Train with no upper bound on the episode count: stop only once the mean reward\n",
    "# over the most recent (at most 200) episodes exceeds 8, the value logged above as\n",
    "# the spec's reward_threshold.\n",
    "logging.info('==== train ====')\n",
    "episode_rewards = []\n",
    "for episode in itertools.count():\n",
    "    # The unwrapped environment is stepped and the step limit is enforced by\n",
    "    # play_episode instead. NOTE(review): presumably so that `done` only signals\n",
    "    # true termination rather than the time limit -- confirm.\n",
    "    episode_reward, elapsed_steps = play_episode(env.unwrapped, agent,\n",
    "            max_episode_steps=env._max_episode_steps, mode='train')\n",
    "    episode_rewards.append(episode_reward)\n",
    "    # DEBUG records are below the INFO level configured for logging, so these are not shown.\n",
    "    logging.debug('train episode %d: reward = %.2f, steps = %d',\n",
    "            episode, episode_reward, elapsed_steps)\n",
    "    if np.mean(episode_rewards[-200:]) > 8:\n",
    "        break\n",
    "# Learning curve: total reward of each training episode.\n",
    "plt.plot(episode_rewards)\n",
    "\n",
    "\n",
    "# Evaluate over 100 episodes. play_episode is called with its default mode=None,\n",
    "# so the agent acts greedily and performs no learning updates.\n",
    "logging.info('==== test ====')\n",
    "# Start a fresh list so the statistics below cover the test episodes only.\n",
    "episode_rewards = []\n",
    "for episode in range(100):\n",
    "    episode_reward, elapsed_steps = play_episode(env, agent)\n",
    "    episode_rewards.append(episode_reward)\n",
    "    logging.debug('test episode %d: reward = %.2f, steps = %d',\n",
    "            episode, episode_reward, elapsed_steps)\n",
    "# Report mean and standard deviation of the per-episode total reward.\n",
    "logging.info('average episode reward = %.2f ± %.2f',\n",
    "        np.mean(episode_rewards), np.std(episode_rewards))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Close the environment now that training and evaluation are finished.\n",
    "env.close()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
